import gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque
import pickle
import time
import threading
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque
import random

import os

# Neural Network (QNetwork)
class QNetwork(nn.Module):
    """Three-layer MLP mapping an observation vector to one Q-value per action.

    Args:
        input_dim: Length of the observation vector fed to ``fc1``.
        output_dim: Number of discrete actions (number of Q-values produced).
        hidden_dim: Width of both hidden layers; defaults to the original 128.

    The attribute names ``fc1``/``fc2``/``fc3`` become the ``state_dict`` keys
    that are pickled and reloaded elsewhere, so they must not be renamed.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unbounded) Q-values; the last dim of ``x`` must equal ``input_dim``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` tuples.

    Once ``capacity`` transitions are held, each new push evicts the oldest one
    (``deque(maxlen=...)`` does the eviction).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=self.capacity)

    def push(self, transition):
        """Store one transition tuple, dropping the oldest when full."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns five arrays ``(states, actions, rewards, next_states, dones)``,
        each produced by ``np.stack`` along a new leading batch axis.
        ``random.sample`` raises ValueError if fewer than ``batch_size``
        transitions are stored.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return tuple(np.stack(column) for column in (states, actions, rewards, next_states, dones))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# Update the model
def update_model(q_online, q_target, optimizer, batch, gamma):
    """Run one gradient step of the one-step TD (Q-learning) update.

    Args:
        q_online: Network being trained; Q(s, a) for the taken action is regressed.
        q_target: Network used, without gradients, to bootstrap max_a' Q(s', a').
        optimizer: Optimizer over ``q_online``'s parameters.
        batch: ``(state, action, reward, next_state, done)`` arrays sharing a
            leading batch dimension B (as produced by ``ReplayBuffer.sample``).
        gamma: Discount factor.

    Returns:
        The scalar MSE loss as a Python float. The original returned None, so
        callers that ignore the result are unaffected.
    """
    state, action, reward, next_state, done = batch

    states = torch.as_tensor(state, dtype=torch.float32)
    actions = torch.as_tensor(action, dtype=torch.long).view(-1, 1)
    rewards = torch.as_tensor(reward, dtype=torch.float32).view(-1)
    next_states = torch.as_tensor(next_state, dtype=torch.float32)
    dones = torch.as_tensor(done, dtype=torch.float32).view(-1)

    # Q(s, a) of the actions actually taken, flattened to shape (B,).
    q_value = q_online(states).gather(1, actions).squeeze(1)

    with torch.no_grad():
        # Keep every term 1-D (B,) so reward/done/next-Q line up element-wise.
        # The previous code added a (B,) reward to a (B, 1) next-Q, which
        # silently broadcast the target to (B, B) and corrupted the loss.
        next_q_value = q_target(next_states).max(dim=1)[0]
        target = rewards + gamma * next_q_value * (1.0 - dones)

    loss = nn.MSELoss()(q_value, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Epsilon-Greedy Action Selection
def epsilon_greedy(q_network, state, epsilon, action_space):
    """Pick an action with an epsilon-greedy policy.

    With probability ``epsilon`` a uniformly random element of ``action_space``
    is returned (``np.random.choice``); otherwise the index of the largest
    Q-value that ``q_network`` outputs for ``state``.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.choice(action_space)
    with torch.no_grad():
        scores = q_network(torch.FloatTensor(state)).cpu().numpy()
    return np.argmax(scores)

# Simulate Data Rate (used for encrypted transmission simulation)
def simulate_data_rate(action='send', size=1024, link_rate_bytes_per_s=1e6):
    """Simulate transferring ``size`` bytes and return the measured rate in Mbps.

    For ``action == 'send'`` the call sleeps ``size / link_rate_bytes_per_s``
    seconds. The default is 1e6 bytes/s (1 MB/s, nominal 8 Mbps). Any other
    action performs no delay.

    Returns:
        Measured throughput in megabits per second, or ``float('inf')`` if the
        measured elapsed time is zero. That can happen for non-'send' actions
        or with a coarse clock; the previous version raised ZeroDivisionError.
    """
    start = time.perf_counter()  # monotonic, high-resolution wall clock
    if action == 'send':
        time.sleep(size / link_rate_bytes_per_s)
    elapsed = time.perf_counter() - start
    if elapsed <= 0:
        return float('inf')
    return (size * 8) / (elapsed * 1e6)  # bits / microseconds == Mbps

# Encryption and Decryption (for Federated Learning with Security)
def generate_dh_parameters(generator=2, key_size=2048):
    """Generate finite-field Diffie-Hellman group parameters.

    Args:
        generator: Group generator (2 by default, as before).
        key_size: Prime modulus size in bits (2048 by default, as before).

    Returns:
        A ``cryptography`` DH parameters object; pass it to ``generate_dh_keypair``.

    NOTE: the ``cryptography`` docs warn that parameter generation can take a
    long time. ``key_size`` is exposed so demos/tests can trade security for speed.
    """
    return dh.generate_parameters(generator=generator, key_size=key_size, backend=default_backend())

def generate_dh_keypair(parameters):
    """Create a DH key pair on ``parameters`` and return ``(private_key, public_key)``."""
    secret = parameters.generate_private_key()
    return secret, secret.public_key()

def derive_shared_key(private_key, peer_public_key):
    """Compute the raw Diffie-Hellman shared secret with a peer.

    ``peer_public_key`` is the peer's public key as PEM bytes. It is parsed,
    then combined with ``private_key`` via ``exchange``. The result is the raw
    secret bytes; the encrypt/decrypt helpers in this file pass it through
    PBKDF2 before using it as an AES key.
    """
    return private_key.exchange(
        serialization.load_pem_public_key(peer_public_key, backend=default_backend())
    )

# PBKDF2 settings shared by both endpoints so they derive the same AES key.
# NOTE(review): a fixed salt is only tolerable because the KDF input is a fresh,
# high-entropy DH secret rather than a password.
_KDF_SALT = b'salt'
_KDF_ITERATIONS = 100000
_GCM_NONCE_LEN = 12  # 96-bit nonce, the recommended size for AES-GCM
_GCM_TAG_LEN = 16    # full 128-bit authentication tag
# Derived keys memoised per shared secret: PBKDF2 with 100k iterations is far
# too slow to repeat on every training step. NOTE(review): this keeps derived
# keys in process memory for the lifetime of the module.
_AES_KEY_CACHE = {}


def _derive_aes_key(shared_key):
    """Derive (and memoise) the 256-bit AES key for a DH shared secret."""
    cache_key = bytes(shared_key)
    key = _AES_KEY_CACHE.get(cache_key)
    if key is None:
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=_KDF_SALT,
            iterations=_KDF_ITERATIONS,
            backend=default_backend()
        )
        key = kdf.derive(cache_key)
        _AES_KEY_CACHE[cache_key] = key
    return key


def encrypt_message(shared_key, message):
    """Encrypt and authenticate ``message`` (bytes) under a key derived from ``shared_key``.

    Uses AES-256-GCM with a fresh random nonce per call. The returned bytes are
    ``nonce (12) || ciphertext || tag (16)``; ``decrypt_message`` expects
    exactly this layout.

    Previously this used AES-ECB with space padding. ECB leaks repeated blocks,
    gives no integrity protection, and the padding had to be ``strip()``-ed,
    which corrupted payloads ending in whitespace bytes.
    """
    key = _derive_aes_key(shared_key)
    nonce = os.urandom(_GCM_NONCE_LEN)
    encryptor = Cipher(algorithms.AES(key), modes.GCM(nonce), backend=default_backend()).encryptor()
    ciphertext = encryptor.update(message) + encryptor.finalize()
    return nonce + ciphertext + encryptor.tag


def decrypt_message(shared_key, encrypted_message):
    """Decrypt the output of ``encrypt_message`` and return the exact original bytes.

    Raises:
        ValueError: if the input is too short to hold a nonce and tag.
        cryptography.exceptions.InvalidTag: if the data was tampered with or
            ``shared_key`` does not match the one used to encrypt.
    """
    if len(encrypted_message) < _GCM_NONCE_LEN + _GCM_TAG_LEN:
        raise ValueError("encrypted message is too short to contain a nonce and tag")
    nonce = encrypted_message[:_GCM_NONCE_LEN]
    tag = encrypted_message[-_GCM_TAG_LEN:]
    ciphertext = encrypted_message[_GCM_NONCE_LEN:-_GCM_TAG_LEN]
    key = _derive_aes_key(shared_key)
    decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce, tag), backend=default_backend()).decryptor()
    # finalize() verifies the tag and raises InvalidTag on mismatch.
    return decryptor.update(ciphertext) + decryptor.finalize()

# Federated DDQN with/without Security
def federated_learning(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, secure=False, target_update_freq=1000):
    """Train a DQN-style agent and simulate shipping its weights after every update.

    After each gradient step the online network's ``state_dict`` is pickled and
    "sent" via ``simulate_data_rate``. With ``secure=True`` the payload is first
    encrypted with a key from an in-process Diffie-Hellman exchange, then
    decrypted and loaded back into the online network.

    Args:
        env_name: Gym environment id. The code assumes the pre-0.26 gym API
            (``reset()`` returns an observation, ``step()`` returns a 4-tuple).
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size; updates start once the buffer holds more.
        gamma: Discount factor.
        epsilon_start, epsilon_end: Exploration schedule endpoints.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        secure: Encrypt the simulated weight transfer.
        target_update_freq: Copy online weights into the target network every
            this many episodes, counting episode 0. The default of 1000 keeps the
            original behaviour.

    Returns:
        ``(avg_scores, mean_data_rate)``: ``avg_scores`` holds, per episode, the
        mean of the last <=100 episode scores. ``mean_data_rate`` is the mean
        simulated rate in Mbps, or ``nan`` when no update ever ran.

    NOTE(review): a later ``def federated_learning`` in this file rebinds this
    name, so this version is shadowed unless one of them is renamed.
    """
    if target_update_freq < 1:
        raise ValueError("target_update_freq must be a positive integer")

    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    learning_rate = 0.001
    betas = (0.9, 0.999)
    optimizer = optim.Adam(q_online.parameters(), lr=learning_rate, betas=betas)

    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    # Linear schedule: epsilon reaches epsilon_end after `epsilon_decay` episodes.
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    data_rates = []  # simulated Mbps per transfer

    # Set up Diffie-Hellman parameters and keys (both endpoints live in this process)
    if secure:
        parameters = generate_dh_parameters()
        server_private_key, server_public_key = generate_dh_keypair(parameters)
        client_private_key, client_public_key = generate_dh_keypair(parameters)
        server_public_key_bytes = server_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        client_public_key_bytes = client_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        shared_key_server = derive_shared_key(server_private_key, client_public_key_bytes)
        shared_key_client = derive_shared_key(client_private_key, server_public_key_bytes)
        print("DH key exchange complete.")

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push((state, action, reward, next_state, done))
            score += reward

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

                # Serialize the freshly updated weights as the "upload" payload.
                model_params = pickle.dumps(q_online.state_dict())
                if secure:
                    encrypted_params = encrypt_message(shared_key_server, model_params)
                    data_rates.append(simulate_data_rate(action='send', size=len(encrypted_params)))
                    # Receiving end: decrypt and load the weights back.
                    decrypted_params = decrypt_message(shared_key_client, encrypted_params)
                    # pickle.loads is only acceptable because these bytes were
                    # produced a few lines above in this same process; never
                    # unpickle data received from a real network peer.
                    q_online.load_state_dict(pickle.loads(decrypted_params))
                else:
                    data_rates.append(simulate_data_rate(action='send', size=len(model_params)))

            state = next_state

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        if episode % 100 == 0:
            print(f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {epsilon:.2f}")

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Update target network periodically
        if episode % target_update_freq == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    # Avoid np.mean([]) (nan + RuntimeWarning) when no update ever ran.
    mean_data_rate = np.mean(data_rates) if data_rates else float('nan')
    return avg_scores, mean_data_rate




# Simulating RSSI (Received Signal Strength Indicator) values
def simulate_rssi(initial_rssi, decay_rate=0.01, noise_level=2.0):
    """
    Simulate the next RSSI reading (dBm) from the previous one.

    The signal weakens by ``decay_rate`` times its magnitude, so RSSI moves
    further below 0 dBm (as when the receiver moves away). Gaussian noise with
    standard deviation ``noise_level`` dB is then added, and the result is
    floored at -120 dBm.

    NOTE(review): a later zero-argument ``def simulate_rssi`` in this file
    rebinds this name, so this version is shadowed at import time.
    """
    # The previous formula, initial_rssi * (1 - decay_rate), moved a negative
    # dBm value *toward* 0 (a stronger signal), contradicting the decay intent
    # and making the -120 dBm floor unreachable without noise.
    rssi = initial_rssi - abs(initial_rssi) * decay_rate
    # Environmental fluctuation
    rssi += np.random.normal(0, noise_level)
    # Keep RSSI above a reasonable receiver floor
    return max(rssi, -120)

# Federated DDQN with or without security (simulating the federated learning process)
def federated_learning(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, secure=False, target_update_freq=1000):
    """
    Train a DQN-style agent and record a simulated RSSI trace per episode.

    ``secure`` does not change training. It only adds a second RSSI trace that
    starts 5 dB lower (-55 vs -50 dBm), standing in for the encrypted link.

    Args:
        env_name: Gym environment id. The code assumes the pre-0.26 gym API
            (``reset()`` returns an observation, ``step()`` returns a 4-tuple).
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size; updates start once the buffer holds more.
        gamma: Discount factor.
        epsilon_start, epsilon_end: Exploration schedule endpoints.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        secure: Also record the secure-link RSSI trace.
        target_update_freq: Copy online weights into the target network every
            this many episodes, counting episode 0. The default of 1000 keeps the
            original behaviour.

    Returns:
        ``(avg_scores, rssi_values)``: ``avg_scores`` has one entry per episode
        (mean of the last <=100 scores). ``rssi_values`` has one reading per
        episode when ``secure`` is False. When ``secure`` is True, plain and
        secure readings are interleaved as ``[plain_0, secure_0, plain_1, ...]``
        (length 2 * num_episodes).
    """
    if target_update_freq < 1:
        raise ValueError("target_update_freq must be a positive integer")

    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    learning_rate = 0.001
    betas = (0.9, 0.999)
    optimizer = optim.Adam(q_online.parameters(), lr=learning_rate, betas=betas)

    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    # Linear schedule: epsilon reaches epsilon_end after `epsilon_decay` episodes.
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    rssi_values = []  # To track RSSI over episodes

    # Initial RSSI for the plain and the (slightly worse) secure link
    rssi_current = -50
    rssi_secure_current = -55

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push((state, action, reward, next_state, done))
            score += reward

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

            state = next_state

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        # The module-level name `simulate_rssi` is rebound later in this file to
        # a zero-argument function, so calling it with the previous reading raised
        # TypeError after the first episode. Use the private fading-step helper.
        rssi_current = _simulate_rssi_step(rssi_current)
        rssi_values.append(rssi_current)

        if secure:
            rssi_secure_current = _simulate_rssi_step(rssi_secure_current)
            rssi_values.append(rssi_secure_current)

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Update target network periodically
        if episode % target_update_freq == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    return avg_scores, rssi_values


def _simulate_rssi_step(previous_rssi, decay_rate=0.01, noise_level=2.0):
    """Return the next RSSI (dBm): fade by ``decay_rate`` of its magnitude, add N(0, noise_level) noise, floor at -120."""
    rssi = previous_rssi - abs(previous_rssi) * decay_rate
    rssi += np.random.normal(0, noise_level)
    return max(rssi, -120)



# Federated DDQN with IP Security
def federated_ddqn_with_security(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, malicious_ips=None, congestion_rate=0.05, target_update_freq=1000):
    """
    Simulate Federated DDQN with IP security, accepting or denying updates based on IP verification.

    Each episode is attributed to one client IP drawn uniformly from
    ``malicious_ips`` plus two benign addresses. On each step, simulated network
    conditions may drop the packet (reward -1) or report congestion (reward -0.5).
    In either case the environment is not stepped and the state is kept.
    Transitions from a malicious IP are not stored in the replay buffer and
    score -2 instead.

    Args:
        env_name: Gym environment id. The code assumes the pre-0.26 gym API.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size; updates start once the buffer holds more.
        gamma: Discount factor.
        epsilon_start, epsilon_end: Exploration schedule endpoints.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        malicious_ips: Iterable of blocked IP strings. ``None`` (the default)
            means none; this replaces the old shared mutable ``[]`` default.
        congestion_rate: Per-step congestion probability for ``simulate_network_conditions``.
        target_update_freq: Copy online weights into the target network every
            this many episodes, counting episode 0. Default 1000 as before.

    Returns:
        ``(avg_scores, congestion_counts, packet_drops, rssi_values)``: per-episode
        lists of the running mean score (last <=100 episodes), congestion events,
        packet drops, and mean RSSI over the episode's steps.

    NOTE(review): a later ``def federated_ddqn_with_security`` in this file
    rebinds this name, so this version is shadowed at import time.
    """
    if target_update_freq < 1:
        raise ValueError("target_update_freq must be a positive integer")
    malicious_ips = [] if malicious_ips is None else list(malicious_ips)
    # Candidate client addresses: every malicious IP plus two benign ones.
    candidate_ips = malicious_ips + ['192.168.1.1', '192.168.1.2']

    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    optimizer = optim.Adam(q_online.parameters())
    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    congestion_counts = []  # congestion events per episode
    packet_drops = []  # packet drops per episode
    rssi_values = []  # mean RSSI per episode

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        score = 0
        congestion_events = 0
        packet_drop_events = 0
        rssi_readings = []
        ip_address = random.choice(candidate_ips)  # simulated client for this episode

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            packet_drop, congestion, rssi = simulate_network_conditions(congestion_rate=congestion_rate)

            if packet_drop:
                packet_drop_events += 1
                reward = -1  # Penalize for packet drop
                next_state = state  # Keep the current state when packet is dropped
            elif congestion:
                congestion_events += 1
                reward = -0.5  # Penalize for congestion
                next_state = state  # Keep the current state when there is congestion
            else:
                next_state, reward, done, _ = env.step(action)  # Normal state transition

            rssi_readings.append(rssi)

            # Only updates from verified IPs reach the replay buffer.
            if ip_security_check(ip_address, malicious_ips):
                replay_buffer.push((state, action, reward, next_state, done))
            else:
                reward = -2  # Heavier penalty for denied updates from malicious agents

            score += reward
            state = next_state

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        congestion_counts.append(congestion_events)
        packet_drops.append(packet_drop_events)
        rssi_values.append(np.mean(rssi_readings))  # Averaging RSSI over the episode

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Periodic update of target network
        if episode % target_update_freq == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    return avg_scores, congestion_counts, packet_drops, rssi_values







def epsilon_greedy(q_network, state, epsilon, action_space):
    """Epsilon-greedy action choice.

    Draws one uniform number first. If it is not below ``epsilon``, return the
    argmax of ``q_network(state)``; otherwise return a random member of
    ``action_space``.
    """
    if not np.random.rand() < epsilon:
        with torch.no_grad():
            q_row = q_network(torch.FloatTensor(state))
            return np.argmax(q_row.cpu().numpy())
    return np.random.choice(action_space)

def simulate_congestion(action='send', congestion_rate=0.1):
    """Return True with probability ``congestion_rate``; ``action`` is accepted but unused."""
    draw = np.random.rand()
    congested = draw < congestion_rate
    return congested

def simulate_packet_drop(probability=0.1):
    """Return True (packet dropped) with the given ``probability``."""
    sample = np.random.rand()
    dropped = sample < probability
    return dropped

def simulate_rssi(initial_rssi=None, decay_rate=0.01, noise_level=2.0):
    """Simulate an RSSI reading in dBm.

    Called with no arguments, return an independent draw from N(-70 dBm, 5 dB),
    which is unchanged behaviour.

    Called with a previous reading ``initial_rssi``, return the next reading of
    a fading signal. This form matches the earlier ``simulate_rssi`` definition
    that this one rebinds, so callers written against either signature keep
    working instead of failing with TypeError. The model is:
    fade by ``decay_rate`` times the magnitude (further below 0 dBm), add
    N(0, ``noise_level``) noise, and floor at -120 dBm.
    """
    if initial_rssi is None:
        return np.random.normal(loc=-70, scale=5)
    rssi = initial_rssi - abs(initial_rssi) * decay_rate
    rssi += np.random.normal(0, noise_level)
    return max(rssi, -120)

def verify_ip_security(malicious_ips, agent_ip):
    """Simulated security check: True unless ``agent_ip`` appears in ``malicious_ips``."""
    is_blocked = agent_ip in malicious_ips
    return not is_blocked

def federated_ddqn_with_security(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, malicious_ips, congestion_rate, target_update_freq=1000):
    """
    Train a DQN-style agent while logging simulated network KPIs and penalising
    steps attributed to a malicious client IP.

    Each episode is attributed to one client IP drawn uniformly from
    ``malicious_ips`` plus two benign addresses. Every step of an episode whose
    IP fails ``verify_ip_security`` gets reward -10. The transition is still
    stored in the replay buffer, as before.

    Args:
        env_name: Gym environment id. The code assumes the pre-0.26 gym API
            (``reset()`` returns an observation, ``step()`` returns a 4-tuple).
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size; updates start once the buffer holds more.
        gamma: Discount factor.
        epsilon_start, epsilon_end: Exploration schedule endpoints.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        malicious_ips: Iterable of blocked IP strings (``None`` is treated as empty).
        congestion_rate: Per-step congestion probability for ``simulate_congestion``.
        target_update_freq: Copy online weights into the target network every
            this many episodes, counting episode 0. Default 1000 as before.

    Returns:
        ``(avg_scores, rssi_readings, packet_drops, congestion_readings)``.
        ``avg_scores`` has one entry per episode. The other three hold one
        entry per environment step across all episodes.
    """
    if target_update_freq < 1:
        raise ValueError("target_update_freq must be a positive integer")
    # The previous version checked str(action) against a hard-coded ['malicious'],
    # so the penalty could never fire and `malicious_ips` was ignored entirely.
    malicious_ips = [] if malicious_ips is None else list(malicious_ips)
    candidate_ips = malicious_ips + ['192.168.1.1', '192.168.1.2']

    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    optimizer = optim.Adam(q_online.parameters())

    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    rssi_readings = []
    congestion_readings = []
    packet_drops = []

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        score = 0
        # One simulated client per episode; verify it once, since it cannot change mid-episode.
        agent_ip = random.choice(candidate_ips)
        agent_trusted = verify_ip_security(malicious_ips, agent_ip)

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            next_state, reward, done, _ = env.step(action)

            # Simulate network KPIs
            rssi = simulate_rssi()
            rssi_readings.append(rssi)

            packet_drop = simulate_packet_drop(probability=0.1)
            packet_drops.append(packet_drop)

            congestion = simulate_congestion(action='send', congestion_rate=congestion_rate)
            congestion_readings.append(congestion)

            # Security check - penalise actions attributed to a malicious agent
            if not agent_trusted:
                reward = -10

            replay_buffer.push((state, action, reward, next_state, done))
            score += reward

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

            state = next_state  # Proceed to the next state

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        if episode % 100 == 0:
            print(f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {epsilon:.2f}")

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Update target network periodically
        if episode % target_update_freq == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    return avg_scores, rssi_readings, packet_drops, congestion_readings




# Simulating network conditions (Packet drops, RSSI, Congestion)
def simulate_network_conditions(packet_drop_rate=0.1, congestion_rate=0.05):
    """
    Draw one step's simulated network state.

    Returns ``(packet_drop, congestion, rssi)``. The first two are independent
    Bernoulli outcomes with the given rates, drawn in that order. ``rssi`` is a
    sample from N(-70 dBm, 10 dB).
    """
    dropped, congested = (np.random.rand() < rate for rate in (packet_drop_rate, congestion_rate))
    signal = np.random.normal(-70, 10)
    return dropped, congested, signal

# Simulated IP security check (defence against poisoning attacks)
def ip_security_check(ip_address, malicious_ips):
    """
    Decide whether an update from ``ip_address`` should be accepted.

    Returns False when the address is listed in ``malicious_ips``, True otherwise.
    """
    if ip_address in malicious_ips:
        return False
    return True


# Define the Binary Cross-Entropy (BCE) loss function
def binary_cross_entropy_loss(y_true, y_pred, epsilon=1e-10):
    """
    Calculate mean Binary Cross-Entropy loss.

    Args:
        y_true: Ground-truth labels (0 or 1); any array-like, including lists.
        y_pred: Predicted probabilities; any array-like. Values are clipped to
            ``[epsilon, 1 - epsilon]`` so log(0) and out-of-range inputs
            (e.g. 1.0000001 from noisy arithmetic) cannot produce inf/nan.
        epsilon: Clipping margin (default 1e-10, the value previously hard-coded).

    Returns:
        The mean BCE over all elements, as a numpy float.
    """
    # asarray lets plain Python lists work; previously `1 - y_true` and
    # `y_pred + epsilon` raised TypeError for list inputs.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.clip(np.asarray(y_pred, dtype=float), epsilon, 1 - epsilon)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Accuracy calculation
def accuracy(y_true, y_pred, threshold=0.5):
    """
    Calculate the classification accuracy in percent.

    Args:
        y_true: Ground-truth labels (0 or 1); any array-like, including lists.
        y_pred: Predicted probabilities; any array-like.
        threshold: A prediction strictly greater than this is labelled 1
            (default 0.5, as before).

    Returns:
        Percentage (0-100) of positions where the thresholded prediction equals the label.
    """
    y_true = np.asarray(y_true)
    # asarray lets list inputs work; `list > float` previously raised TypeError.
    y_pred_labels = (np.asarray(y_pred) > threshold).astype(int)
    return np.mean(y_true == y_pred_labels) * 100

# Generate dummy predictions and ground truths for two algorithms.
# Every np.random call below is order-sensitive under the fixed seed: the
# "without security" predictions are still drawn (though not plotted) so the
# "with security" series stays reproducible.
np.random.seed(25)  # For reproducibility
num_episodes = 300
y_true = np.random.randint(0, 2, num_episodes)  # Random binary labels (ground truth)

# Generate predictions for each algorithm
federated_without_security_pred = np.clip(
    y_true + np.random.normal(0.5, 0.05, num_episodes), 0, 1
)
federated_with_security_pred = np.clip(
    y_true + np.random.normal(0.5, 0.05, num_episodes), 0, 1
)

# Cumulative BCE loss and accuracy after each round (over rounds 1..i+1)
federated_with_security_loss = [
    binary_cross_entropy_loss(y_true[:i + 1], federated_with_security_pred[:i + 1])
    for i in range(num_episodes)
]

federated_with_security_accuracy = [
    accuracy(y_true[:i + 1], federated_with_security_pred[:i + 1])
    for i in range(num_episodes)
]

# Print metrics for each round
for i in range(num_episodes):
    print(f"Rounds {i+1}/{num_episodes}:")
    print(f"  FL xAPP w/ Sec -> Loss: {federated_with_security_loss[i]:.4f}, #Accuracy: {federated_with_security_accuracy[i]:.2f}%")

# Plot BCE loss (left y-axis) and accuracy (right y-axis) on one figure.
# A separate plt.figure() call used to precede plt.subplots(), leaving an
# extra empty figure window; plt.subplots() alone creates the figure.
# Rounds are 1-based to match the printed "Rounds i+1/N" labels.
episodes = np.arange(1, num_episodes + 1)

fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.set_xlabel('Rounds', fontsize=26)
ax1.set_ylabel('BCE Loss', fontsize=26, color='black')
ax1.plot(episodes, federated_with_security_loss, label="FL xAPP w/ Sec (Loss)", color='blue', linestyle='-', marker='D')

# Right y-axis for Accuracy
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy Rate (%)', fontsize=16, color='black')
ax2.plot(episodes, federated_with_security_accuracy, label="FL xAPP w/ Sec (Accuracy)", color='green', linestyle='--', marker='*')

# Combine legends from both axes
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='center', fontsize=22)

ax1.tick_params(axis='x', labelsize=17)
ax1.tick_params(axis='y', labelsize=17)
ax2.tick_params(axis='y', labelcolor='black', labelsize=17)

# Lay out and display the figure (nothing is written to disk)
plt.tight_layout()
plt.show()





# Example Adam hyperparameters kept at module level. NOTE(review): none of the
# functions shown in this file read these; each training function sets its own.
learning_rate = 0.1
# beta2 was 0.00999, almost certainly a typo for 0.999 (the value used by
# federated_learning). A beta2 near 0 makes Adam's second-moment estimate track
# only the latest squared gradient, which defeats its adaptive step size.
betas = (0.9, 0.999)
